#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0

####################### IsolatePass.py #########################################
#
# Copyright 2025 The IBM Research Authors.
#
################################################################################
#
# This script extracts and prints the IR dump of one specific pass from a log
# file that lists the dumps of several passes, e.g. a file generated by
# `onnx-mlir --mlir-print-ir-after-all`.
#
################################################################################

import argparse
import os
import re

# Tunables.
# Max length of a single line. If there are very long constants, truncate them.
max_line_length = 800
# Set to a truthy value to print extra information while matching pass names.
debug = 0

# Tables (re)built by scan_listing(); ids are assigned from 0 in file order.
# Pass name -> list of the ids of the dumps bearing that name.
pass_name_to_id = dict()
# Pass id -> pass name.
id_to_pass_name = list()
# Pass id -> text of the dump.
pass_listing = list()


def get_args():
    """Build the command-line parser and parse the process arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input``, ``pass_name``,
        ``num`` and ``after`` (each a string, or None when the option was not
        given) and ``list_passes`` (a bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input",
        help="Provide the NAME input log file to parse, e.g. a file "
        "generated by an 'onnx-mlir --mlir-print-ir-after-all' command.",
        metavar="NAME",
    )
    parser.add_argument(
        "-p",
        "--pass-name",
        help="Name of the pass to list. When the regexp matches "
        "multiple passes, use the -n option to indicate which match to select.",
        metavar="REGEX",
    )
    # Kept as a string (no type=int): the rest of the script tests the option
    # for truthiness, and "0" must remain a valid, truthy selection.
    parser.add_argument(
        "-n",
        "--num",
        help="Id NUM of the pass to list when -p is not specified. "
        "Otherwise, list the NUM(th) match returned by the -p regex string.",
        metavar="NUM",
    )
    parser.add_argument(
        "-a",
        "--after",
        help="If STR is a number N, print the Nth pass after the pass indicated "
        "by the -p or -n options. N can be negative. Otherwise, print the next "
        "pass whose name matches the regex STR.",
        metavar="STR",
    )
    parser.add_argument(
        "-l",
        "--list-passes",
        action="store_true",
        help="List the name of every pass. Default when neither -p nor -n is given.",
    )
    return parser.parse_args()


def usage(error_message):
    """Report a command-line error and terminate the program.

    Prints ``error_message`` prefixed by "Error:", followed by a pointer to
    the help option.

    Args:
        error_message: Text describing what is wrong with the invocation.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    print("Error:", error_message)
    print("Use -h / --help for more information.")
    # The `exit` builtin is injected by the `site` module for interactive use
    # and is missing under `python -S` or in frozen programs; raising
    # SystemExit directly has the same effect without that dependency.
    raise SystemExit(1)


def extract_ir_pass_name(text):
    """Return the pass name found in an "IR Dump After" header, if any.

    Args:
        text: String to search for a header.

    Returns:
        The text between "IR Dump After" and the closing " //----- //" marker
        with surrounding whitespace removed, or None when ``text`` holds no
        such header.
    """
    header = re.search(r"// -----// IR Dump After(.*?) //----- //", text)
    if header is None:
        return None
    return header.group(1).strip()


def process_line(str, length=max_line_length):
    """Normalize one line of the log for storage in a pass listing.

    The line is truncated to at most ``length`` characters, a trailing
    newline (if still present) is dropped, a closing '"' is appended when the
    number of '"' characters is odd, and a single newline is appended.

    Args:
        str: The raw line.
        length: Maximum number of characters kept from ``str``.

    Returns:
        The normalized line, always ending with a newline.
    """
    # Truncate overly long lines.
    text = str[:length] if len(str) > length else str
    # Drop the end of line, if it survived the truncation.
    if text.endswith("\n"):
        text = text[:-1]
    # Keep the double quotes balanced.
    closing = '"' if text.count('"') % 2 == 1 else ""
    return text + closing + "\n"


def _save_pass(name, lines, print_list_name):
    """Record one completed pass dump in the module-level tables."""
    pass_id = len(pass_listing)
    pass_listing.append("".join(lines))
    id_to_pass_name.append(name)
    pass_name_to_id.setdefault(name, []).append(pass_id)
    # Print info if requested.
    if print_list_name:
        print(f"{pass_id}: {name}")


def scan_listing(filename, print_list_name):
    """Split the log ``filename`` into one listing per pass.

    Resets and then fills the module-level tables ``pass_name_to_id``,
    ``id_to_pass_name`` and ``pass_listing``. A listing starts at an
    "IR Dump After" header line (which is kept as its first line) and runs
    up to the next header or the end of the file. Lines before the first
    header are discarded. Lines starting with "onnx-mlir " are echoed as the
    command and are not added to any listing.

    Args:
        filename: Path of the log file to read.
        print_list_name: When true, print "<id>: <pass name>" for every pass.

    Raises:
        FileNotFoundError: ``filename`` does not exist (re-raised after an
            error message is printed).
        IOError: ``filename`` could not be read (re-raised after an error
            message is printed).
    """
    global pass_name_to_id, id_to_pass_name, pass_listing
    pass_name_to_id = {}
    id_to_pass_name = []
    pass_listing = []

    current_listing = []
    current_name = ""
    try:
        with open(filename, "r") as file:
            for line in file:
                if line.startswith("onnx-mlir "):
                    print(f"// Command:\n// {line}\n")
                    continue
                pass_name = extract_ir_pass_name(line)
                if pass_name:
                    if current_name and current_listing:
                        # A new header closes the listing in progress.
                        _save_pass(current_name, current_listing, print_list_name)
                    # Start the listing of the new pass.
                    current_name = pass_name
                    current_listing = []
                current_listing.append(process_line(line))
    except FileNotFoundError:
        print(f"Error: The file '{filename}' was not found.")
        raise
    except IOError as e:
        print(f"Error: Could not read file '{filename}'. Reason: {e}")
        raise

    # The last pass has no following header to close it; save it here so that
    # it is not silently dropped.
    if current_name and current_listing:
        _save_pass(current_name, current_listing, print_list_name)


def locate_pass(name, num):
    """Return the id of the pass whose name matches the regex ``name``.

    Every pass name recorded in ``pass_name_to_id`` is searched with
    ``name``; the ids of all the matching names are collected, grouped by
    name in the order the names were first recorded.

    Args:
        name: Regular expression searched in the pass names.
        num: When several ids match, the 0-based index (string or int) of the
            match to return; ignored when exactly one id matches. A falsy
            value means "not provided".

    Returns:
        The selected pass id.

    Raises:
        SystemExit: With status 1 when nothing matches, when ``num`` is not
            an integer or is out of range, or when the match is ambiguous and
            ``num`` was not provided.
        re.error: ``name`` is not a valid regular expression.
    """
    ids = []
    regex = re.compile(name)
    for key, curr_ids in pass_name_to_id.items():
        if regex.search(key):
            ids.extend(curr_ids)
            if debug:
                print(f" Matched key: {key} -> Value: {curr_ids}")
    num_ids = len(ids)
    if num_ids == 0:
        usage(
            f"pass {name} not found; please check the name of the pass you are looking for."
        )
    if num_ids > 1:
        if num:
            try:
                n = int(num)
            except ValueError:
                usage(f"Provided a -n {num} value that is not an integer.")
            if n < 0 or n >= num_ids:
                usage(
                    f"Provided a -n {n} number that is not in the range [0..{num_ids})."
                )
            return ids[n]
        print(f"Pass {name} is ambiguous and matched with the following ids:")
        print("  ", ids)
        # The value given to -n indexes the list printed above; it is not one
        # of the ids themselves.
        print(
            "  Please choose which one with the -n <index> option "
            f"(0-based position in the list above, in [0..{num_ids})) to select a unique pass."
        )
        raise SystemExit(1)
    return ids[0]


def _exit_unless_valid_id(pass_id):
    """Exit with status 1 unless ``pass_id`` indexes an entry of ``pass_listing``."""
    if pass_id < 0 or pass_id >= len(pass_listing):
        print(
            f"Out of bound id {pass_id}, should be in [0..{len(pass_listing)}) range."
        )
        raise SystemExit(1)


def print_pass(filename, id, after, pass_name=None):
    """Print the listing of a pass, framed by a descriptive comment line.

    Args:
        filename: Name of the scanned log file; only used in the message.
        id: Id of the reference pass.
        after: Optional offset from the reference pass. A string holding a
            signed integer N selects the pass N positions after ``id``
            (before it when N is negative). Any other non-empty string is a
            regex: the first pass after ``id`` whose name matches is selected.
            A falsy value selects the pass ``id`` itself.
        pass_name: Optional name used to locate ``id``; only used in messages.

    Raises:
        SystemExit: With status 1 when the selected id is out of range or
            when no pass after ``id`` matches the regex ``after``.
        re.error: ``after`` is not a number and not a valid regex.
    """
    n = 0
    if after:
        if re.fullmatch(r"[+-]?\d+", after) is not None:
            # After is a number.
            n = int(after)
        else:
            # After is a pass name, locate it. Check the starting point first:
            # a negative id would otherwise wrap around through Python's
            # negative indexing and scan from the wrong place.
            _exit_unless_valid_id(id)
            regex = re.compile(after)
            origin = pass_name if pass_name else f"with id {id}"
            for candidate in range(id + 1, len(id_to_pass_name)):
                if regex.search(id_to_pass_name[candidate]):
                    # Found it.
                    n = candidate - id
                    break
            else:
                usage(f"Did not find pass {after} after pass {origin}")
        id += n
    _exit_unless_valid_id(id)
    message = f"// Printing pass with id {id}"
    if pass_name:
        if n > 0:
            message += f", {n}(th) pass(s) after {pass_name}"
        elif n < 0:
            message += f", {-n}(th) pass(s) before {pass_name}"
        else:
            message += f" with name {pass_name}"
    message += f' from file "{filename}".'
    # Emit a single string: passing several arguments to print() would insert
    # a separator space before and after the listing.
    print(f"{message}\n\n{pass_listing[id]}\n\n{message}")


def main():
    """Run the tool: parse the arguments, scan the log, then list or print.

    The pass names are listed when -l is given or when neither -p nor -n is;
    otherwise the pass selected by -p (optionally disambiguated by -n) or by
    -n alone is printed, shifted by -a when provided.
    """
    # Process arguments.
    args = get_args()
    if args.input is None:
        usage("missing file name")
    if args.pass_name is None and args.num is None:
        args.list_passes = True

    # Perform the listing.
    scan_listing(args.input, args.list_passes)
    if args.list_passes:
        return
    if args.pass_name:
        pass_id = locate_pass(args.pass_name, args.num)
        print_pass(args.input, pass_id, args.after, args.pass_name)
    elif args.num:
        try:
            pass_id = int(args.num)
        except ValueError:
            usage(f"Provided a -n {args.num} value that is not an integer.")
        print_pass(args.input, pass_id, args.after)


# Only run when executed as a script, so that importing this module has no
# side effects.
if __name__ == "__main__":
    main()
